/*
 * Created on Oct 16, 2006
 */
package edu.swri.swiftvis.filters;

import java.awt.BorderLayout;
import java.awt.GridLayout;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JLabel;
import javax.swing.JPanel;

import edu.swri.swiftvis.BooleanFormula;
import edu.swri.swiftvis.DataElement;
import edu.swri.swiftvis.DataFormula;
import edu.swri.swiftvis.GraphElement;
import edu.swri.swiftvis.util.DisjointSets;
import edu.swri.swiftvis.util.EditableBoolean;
import edu.swri.swiftvis.util.EditableString;
import edu.swri.swiftvis.util.KDTree;
import edu.swri.swiftvis.util.ListOfPropertyComponents;

/**
 * This filter will find clusters of elements as distributed through some space.  Each
 * element will be labeled with the cluster number it is in, the number of elements in
 * that cluster, etc.  Various parameters allow users to select how the clusters are
 * found.
 * 
 * Allow them to enter formulas for coordinates.  I could also allow a choice or a
 * formula for picking the full distance formula.  I will have the prime elements be
 * the second one.  They won't be able to use group selections, but it should work
 * well other than that.
 * 
 * Options:
 * Compare groups or compare individuals
 * For comparing groups you have to be able to specify group variables.  Might limit to summations and weighted averages.
 * Spatial needs formulas for dimensions and a search radius formula.
 * For first pass consider a button that sets things up for gravity.
 * 
 * The output has two streams per input stream.  The first copies all elements and adds a parameter telling which group it is in.
 * The second has the groups in order, p[0]=group number, and the group variables as well for the final groupings.
 * 
 * @author Mark Lewis
 */
public class ClusterFilter extends AbstractMultipleSourceFilter {
    /**
     * Creates a filter with the default settings given by the field initializers:
     * both options off, no variable definitions, a search distance formula of "1"
     * and a merge criteria of "1=1".
     */
    public ClusterFilter() {}
    
    /**
     * Copy constructor backing {@link #copy(List)}.  The superclass copies what it
     * owns; every setting, formula and variable definition of this class is then
     * duplicated so the new filter shares no editable state with <code>c</code>.
     * 
     * @param c the filter to duplicate
     * @param l the element list that is handed on to the superclass copy constructor
     */
    private ClusterFilter(ClusterFilter c,List<GraphElement> l) {
        super(c,l);

        // Simple settings and formulas get fresh wrapper objects.
        compareClusters=new EditableBoolean(c.compareClusters.getValue());
        useSpatialTree=new EditableBoolean(c.useSpatialTree.getValue());
        searchDistance=new DataFormula(c.searchDistance);
        mergeCriteria=new BooleanFormula(c.mergeCriteria);

        // The definition lists were created empty by the field initializers, so
        // they only need to be filled with element-wise copies.
        for(int k=0; k<c.spatialValues.size(); ++k) {
            spatialValues.add(new VariableDefinition(c.spatialValues.get(k)));
        }
        for(int k=0; k<c.generalValues.size(); ++k) {
            generalValues.add(new VariableDefinition(c.generalValues.get(k)));
        }
        for(int k=0; k<c.clusterValues.size(); ++k) {
            clusterValues.add(new ClusterVariableDefinition(c.clusterValues.get(k)));
        }
    }

    /**
     * Rebuilds the two output streams that belong to each stream of source 0.
     * <p>
     * For input stream <code>s</code> the elements inside the safe range of the
     * formulas are grouped with a {@link DisjointSets} structure.  Output stream
     * <code>2*s</code> receives every element of that range with the index of its
     * group representative appended as an extra parameter.  Output stream
     * <code>2*s+1</code> receives one element per group: parameter 0 is the
     * representative's index and the values are the cluster variables of the
     * whole group.
     * <p>
     * Two strategies exist.  With "compare clusters" on, only current group
     * representatives are compared, using the accumulated cluster variables
     * (suffix "i"/"j" in the formulas), and passes repeat until a pass performs
     * no merge.  With it off, individual elements are compared once.  In both
     * cases the optional spatial tree limits the candidates to the neighbors
     * returned by the tree instead of testing all later elements.
     * <p>
     * Differences from the earlier version of this method: (1) in the
     * element-by-element mode the cluster sums are now accumulated before the
     * group stream is written (previously only the representative's own values
     * were reported); (2) groups are written in ascending order of their
     * representative index instead of in hash-set iteration order.
     */
    @Override
    protected void redoAllElements() {
        sizeDataVectToInputStreams();
        for(int s=0; s<getSource(0).getNumStreams(); ++s) {
            ClusterInfo[] ci=new ClusterInfo[getSource(0).getNumElements(s)];
            // Only work on elements for which every formula can be evaluated.
            // NOTE(review): searchDistance is not included in the range merge - confirm that is intended.
            int[] range=mergeCriteria.getSafeElementRange(this,s);
            for(int i=0; i<spatialValues.size(); ++i) {
                DataFormula.mergeSafeElementRanges(range,spatialValues.get(i).value.getSafeElementRange(this,s));
            }
            for(int i=0; i<generalValues.size(); ++i) {
                DataFormula.mergeSafeElementRanges(range,generalValues.get(i).value.getSafeElementRange(this,s));
            }
            for(int i=0; i<clusterValues.size(); ++i) {
                DataFormula.mergeSafeElementRanges(range,clusterValues.get(i).value.getSafeElementRange(this,s));
                DataFormula.mergeSafeElementRanges(range,clusterValues.get(i).weight.getSafeElementRange(this,s));
            }
            DataFormula.checkRangeSafety(range,this);
            // Every element starts as its own cluster holding just its own contribution.
            for(int i=range[0]; i<range[1]; ++i) {
                ci[i]=new ClusterInfo(this,s,i);
            }
            DisjointSets ds=new DisjointSets(ci.length);
            KDTree<CFTreePoint> tree=null;
            List<CFTreePoint> spatialPoints=new ArrayList<CFTreePoint>();
            if(useSpatialTree.getValue()) {
                if(statusLabel!=null) statusLabel.setText("Building tree");
                // spatialPoints is indexed by (element index - range[0]).
                double[] vals=new double[spatialValues.size()];
                for(int i=range[0]; i<range[1]; ++i) {
                    for(int j=0; j<vals.length; ++j) {
                        vals[j]=spatialValues.get(j).value.valueOf(this,s,i);
                    }
                    spatialPoints.add(new CFTreePoint(i,vals));
                }
                double[] scaleArray=new double[vals.length];
                for(int i=0; i<scaleArray.length; ++i) scaleArray[i]=1;
                tree=new KDTree<CFTreePoint>(spatialPoints,vals.length,3,scaleArray);
            }
            Map<String,Double> varHash=new HashMap<String,Double>();
            // At most (number of elements - 1) merges are possible; once reached everything is one group.
            int mergeCount=0;
            int pass=0;
            if(compareClusters.getValue()) {
                // Cluster against cluster: repeat passes until one finishes without any merge.
                boolean flag=true;
                while(flag && mergeCount<range[1]-range[0]-1) {
                    flag=false;
                    pass++;
                    for(int i=range[0]; i<range[1]; ++i) {
                        if(statusLabel!=null && i%10000==0) statusLabel.setText("Pass="+pass+" i="+i);
                        // Only elements that currently represent their set are used as the "i" side.
                        if(i==ds.findSet(i)) {
                            fillGroupVariables(varHash,ci[i],"i");
                            if(useSpatialTree.getValue()) {
                                fillSpatialVariables(varHash,spatialPoints.get(i-range[0]),"i");
                                // The same search distance is used along every spatial dimension.
                                double distVal=searchDistance.valueOf(this,s,i,null,varHash);
                                double[] dist=new double[spatialValues.size()];
                                for(int j=0; j<dist.length; ++j) dist[j]=distVal;
                                List<CFTreePoint> neighbors=tree.getNearPoints(spatialPoints.get(i-range[0]),dist);
                                for(CFTreePoint point:neighbors) {
                                    int j=point.index;
                                    if(j==ds.findSet(j) && j!=i) {
                                        fillGroupVariables(varHash,ci[j],"j");
                                        fillSpatialVariables(varHash,point,"j");
                                        int[] ja={j};
                                        fillOtherVariables(varHash,i,ja,s);
                                        if(checkMergeBetween(i,ja,s,varHash,ds)) {
                                            flag=true;
                                            // mergeWith stores the combined sums in both infos.
                                            ci[i].mergeWith(ci[j]);
                                            mergeCount++;
                                        }
                                    }
                                }
                            } else {
                                for(int j=i+1; j<range[1]; ++j) {
                                    if(j==ds.findSet(j)) {
                                        fillGroupVariables(varHash,ci[j],"j");
                                        int[] ja={j};
                                        fillOtherVariables(varHash,i,ja,s);
                                        if(checkMergeBetween(i,ja,s,varHash,ds)) {
                                            flag=true;
                                            ci[i].mergeWith(ci[j]);
                                            mergeCount++;
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            } else {
                // Element against element: a single sweep, no cluster variables are involved.
                for(int i=range[0]; i<range[1] && mergeCount<range[1]-range[0]-1; ++i) {
                    if(statusLabel!=null && i%10000==0) statusLabel.setText("i="+i);
                    if(useSpatialTree.getValue()) {
                        fillSpatialVariables(varHash,spatialPoints.get(i-range[0]),"i");
                        double distVal=searchDistance.valueOf(this,s,i,null,varHash);
                        double[] dist=new double[spatialValues.size()];
                        for(int j=0; j<dist.length; ++j) dist[j]=distVal;
                        List<CFTreePoint> neighbors=tree.getNearPoints(spatialPoints.get(i-range[0]),dist);
                        for(CFTreePoint point:neighbors) {
                            if(ds.findSet(point.index)!=ds.findSet(i)) {
                                fillSpatialVariables(varHash,point,"j");
                                int[] ja={point.index};
                                fillOtherVariables(varHash,i,ja,s);
                                if(checkMergeBetween(i,ja,s,varHash,ds)) mergeCount++;
                            }
                        }
                    } else {
                        for(int j=i+1; j<range[1]; ++j) {
                            if(ds.findSet(j)!=ds.findSet(i)) {
                                int[] ja={j};
                                fillOtherVariables(varHash,i,ja,s);
                                if(checkMergeBetween(i,ja,s,varHash,ds)) mergeCount++;
                            }
                        }
                    }
                }
                // No ClusterInfo was merged in this mode, so each info still holds a single
                // element's contribution.  Fold every non-representative element into its
                // representative so the group stream reports totals for the whole group.
                for(int i=range[0]; i<range[1]; ++i) {
                    int g=ds.findSet(i);
                    if(g!=i) ci[g].mergeWith(ci[i]);
                }
            }
            // Element stream: a copy of each input element tagged with its group.
            for(int i=range[0]; i<range[1]; ++i) {
                int g=ds.findSet(i);
                dataVect.get(2*s).add(new DataElement(getSource(0).getElement(i,s),g));
            }
            // Group stream: one element per representative, in ascending index order.
            int[] params=new int[1];
            float[] values=new float[clusterValues.size()]; 
            for(int g=range[0]; g<range[1]; ++g) {
                if(g!=ds.findSet(g)) continue;
                params[0]=g;
                for(int i=0; i<values.length; ++i) {
                    values[i]=(float)clusterValues.get(i).calcValue(ci[g].valueSum[i],ci[g].weightSum[i]);
                }
                dataVect.get(2*s+1).add(new DataElement(params,values));
            }
            if(statusLabel!=null) statusLabel.setText("Not Processing");
        }
    }
    
    /**
     * Publishes the coordinates of a tree point as formula variables.  The i-th
     * coordinate is stored under the name of the i-th spatial definition with
     * <code>suffix</code> appended (for example "x"+"i").
     * <p>
     * The key is built from <code>name.getValue()</code>, the same accessor the
     * other fill methods use, so that the variable is registered under the name
     * the user typed.
     * 
     * @param varHash map of variable names to values that is handed to the formulas
     * @param point the point whose coordinates are published
     * @param suffix text appended to each variable name
     */
    private void fillSpatialVariables(Map<String,Double> varHash,CFTreePoint point,String suffix) {
        for(int i=0; i<point.v.length; ++i) {
            varHash.put(spatialValues.get(i).name.getValue()+suffix,point.v[i]);
        }
    }
    
    /**
     * Publishes the current cluster variables of one cluster as formula variables.
     * Each cluster definition is evaluated from the accumulated value and weight
     * sums and stored under its name with <code>suffix</code> appended.
     * 
     * @param varHash map of variable names to values that is handed to the formulas
     * @param gi accumulated sums of the cluster
     * @param suffix text appended to each variable name
     */
    private void fillGroupVariables(Map<String,Double> varHash,ClusterInfo gi,String suffix) {
        int k=0;
        for(ClusterVariableDefinition def:clusterValues) {
            double combined=def.calcValue(gi.valueSum[k],gi.weightSum[k]);
            varHash.put(def.name.getValue()+suffix,combined);
            k++;
        }
    }
    
    /**
     * Evaluates the general formulas in list order and stores each result under
     * its variable name.  The map being filled is also passed to the evaluation,
     * so a formula can use variables stored earlier, including earlier general
     * variables.
     * 
     * @param varHash map of variable names to values; read and extended
     * @param i index of the first element
     * @param j indices passed on to the formula evaluation as the other element(s)
     * @param s input stream
     */
    private void fillOtherVariables(Map<String,Double> varHash,int i,int[] j,int s) {
        for(int k=0; k<generalValues.size(); ++k) {
            VariableDefinition def=generalValues.get(k);
            String key=def.name.getValue();
            varHash.put(key,def.value.valueOf(this,s,i,j,varHash));
        }
    }
    
    /**
     * Evaluates the merge criteria for element <code>i</code> against
     * <code>j</code> and, when it holds, unions <code>i</code> with
     * <code>j[0]</code> in the disjoint sets.
     * 
     * @param i index of the first element
     * @param j indices of the other element(s); only <code>j[0]</code> is unioned
     * @param s input stream
     * @param varHash variables made available to the merge criteria
     * @param ds disjoint sets that record the grouping
     * @return true if the criteria held and the union was performed
     */
    private boolean checkMergeBetween(int i,int[] j,int s,Map<String,Double> varHash,DisjointSets ds) {
        boolean shouldMerge=mergeCriteria.valueOf(this,s,i,j,varHash);
        if(!shouldMerge) {
            return false;
        }
        ds.union(i,j[0]);
        return true;
    }
    
    /**
     * Adds the two filter specific tabs to the properties panel.
     * <p>
     * "Basic and Clusters" holds the gravity preset button, the compare clusters
     * check box, the list of cluster variables, the merge criteria, the status
     * label and a propagate button.  "Spatial and Variables" holds the spatial
     * tree check box, the search distance, the spatial and general formula lists
     * and a second propagate button.  The gravity button replaces all definitions
     * and formulas with a fixed preset and refreshes the three list components.
     */
    @Override
    protected void setupSpecificPanelProperties() {
        // Both "Propagate changes" buttons do the same thing, so they share one listener.
        final ActionListener propagate=new ActionListener() {
            @Override public void actionPerformed(ActionEvent e) { abstractRedoAllElements(); }
        };

        // ---- first tab ----
        JPanel clusterTab=new JPanel(new BorderLayout());
        JPanel clusterTop=new JPanel(new GridLayout(2,1));
        JButton gravityButton=new JButton("Setup for Gravity");
        clusterTop.add(gravityButton);
        clusterTop.add(compareClusters.getCheckBox("Compare Clusters?",null));
        clusterTab.add(clusterTop,BorderLayout.NORTH);
        final ListOfPropertyComponents<ClusterVariableDefinition> clusterList=
            new ListOfPropertyComponents<ClusterVariableDefinition>(clusterValues, "Cluster Values",ClusterVariableDefinition.class);
        clusterTab.add(clusterList,BorderLayout.CENTER);
        JPanel clusterBottom=new JPanel(new GridLayout(3,1));
        clusterBottom.add(mergeCriteria.getLabeledTextField("Merge Criteria",null));
        statusLabel=new JLabel("Not Processing");
        clusterBottom.add(statusLabel);
        JButton clusterPropagate=new JButton("Propagate changes");
        clusterPropagate.addActionListener(propagate);
        clusterBottom.add(clusterPropagate);
        clusterTab.add(clusterBottom,BorderLayout.SOUTH);
        propPanel.addTab("Basic and Clusters",clusterTab);

        // ---- second tab ----
        JPanel spatialTab=new JPanel(new BorderLayout());
        JPanel spatialTop=new JPanel(new GridLayout(2,1));
        spatialTop.add(useSpatialTree.getCheckBox("Use Spatial Tree?",null));
        spatialTop.add(searchDistance.getLabeledTextField("Search Distance",null));
        spatialTab.add(spatialTop,BorderLayout.NORTH);
        JPanel formulaLists=new JPanel(new GridLayout(2,1));
        final ListOfPropertyComponents<VariableDefinition> spatialList=
            new ListOfPropertyComponents<VariableDefinition>(spatialValues,"Spatial Formulas",VariableDefinition.class);
        formulaLists.add(spatialList);
        final ListOfPropertyComponents<VariableDefinition> generalList=
            new ListOfPropertyComponents<VariableDefinition>(generalValues,"General Formulas",VariableDefinition.class);
        formulaLists.add(generalList);
        spatialTab.add(formulaLists,BorderLayout.CENTER);
        JButton spatialPropagate=new JButton("Propagate changes");
        spatialPropagate.addActionListener(propagate);
        spatialTab.add(spatialPropagate,BorderLayout.SOUTH);
        propPanel.addTab("Spatial and Variables",spatialTab);

        // ---- gravity preset ----
        gravityButton.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent arg0) {
                compareClusters.setValue(true);
                useSpatialTree.setValue(true);

                // Spatial coordinates: {name, formula}.
                String[][] spatialDefs={{"x","v[0]"},{"y","v[1]"},{"z","v[2]"}};
                spatialValues.clear();
                for(String[] def:spatialDefs) {
                    spatialValues.add(new VariableDefinition(def[0],def[1]));
                }
                spatialList.dataUpdated();

                // Cluster variables: a summed mass followed by weighted averages that all
                // use the same weight formula.
                String averageWeight="v[6]*v[6]*v[6]";
                String[][] averagedDefs={
                    {"cx","v[0]"},{"cy","v[1]"},{"cz","v[2]"},
                    {"cvx","v[3]"},{"cvy","v[4]"},{"cvz","v[5]"}
                };
                clusterValues.clear();
                clusterValues.add(new ClusterVariableDefinition("mass","4/3*PI*v[6]*v[6]*v[6]","1",0));
                for(String[] def:averagedDefs) {
                    clusterValues.add(new ClusterVariableDefinition(def[0],def[1],averageWeight,1));
                }
                clusterList.dataUpdated();

                // General variables: {name, formula}.  Order matters because later
                // formulas use the names defined by earlier ones.
                String[][] generalDefs={
                    {"dx","cxi-cxj"},
                    {"dy","cyi-cyj"},
                    {"dz","czi-czj"},
                    {"dist","sqrt(dx*dx+dy*dy+dz*dz)"},
                    {"cmvx","(cvxi*massi+cvxj*massj)/(massi+massj)"},
                    {"cmvy","(cvyi*massi+cvyj*massj)/(massi+massj)"},
                    {"cmvz","(cvzi*massi+cvzj*massj)/(massi+massj)"},
                    {"dvxi","cvxi-cmvx"},
                    {"dvyi","cvyi-cmvy"},
                    {"dvzi","cvzi-cmvz"},
                    {"dvxj","cvxj-cmvx"},
                    {"dvyj","cvyj-cmvy"},
                    {"dvzj","cvzj-cmvz"},
                    {"visqr","dvxi*dvxi+dvyi*dvyi+dvzi*dvzi"},
                    {"vjsqr","dvxj*dvxj+dvyj*dvyj+dvzj*dvzj"},
                    {"energy","0.5*(massi*visqr+massj*vjsqr)-massi*massj/dist"}
                };
                generalValues.clear();
                for(String[] def:generalDefs) {
                    generalValues.add(new VariableDefinition(def[0],def[1]));
                }
                generalList.dataUpdated();

                searchDistance.setFormula("cbrt(massi/3)*1.6");
                mergeCriteria.setFormula("energy<0");
            }
        });
    }

    /**
     * Always reports false for this filter.
     * NOTE(review): presumably this tells the superclass not to run the redo work
     * in separate threads - confirm against the base class.
     * 
     * @return false
     */
    @Override
    protected boolean doingInThreads() {
        return false;
    }

    /**
     * Gives the number of parameters of an output stream.  Even output streams
     * carry the parameters of input stream <code>stream/2</code> plus the group
     * number; odd output streams carry only the group number.
     * 
     * @param stream the output stream
     * @return the parameter count, or 0 when there is no source
     */
    public int getNumParameters(int stream) {
        if(getNumSources()<1) return 0;
        // Odd streams are the per-group streams.
        if(stream%2!=0) return 1;
        return getSource(0).getNumParameters(stream/2)+1;
    }

    /**
     * Describes a parameter of an output stream.  On even output streams the
     * parameters that come from input stream <code>stream/2</code> keep the
     * description given by the source; every other parameter is the group number.
     * 
     * @param stream the output stream
     * @param which the parameter index
     * @return the description, or "None" when there is no source
     */
    public String getParameterDescription(int stream, int which) {
        if(getNumSources()<1) return "None";
        int inStream=stream/2;
        boolean passedThrough=stream%2==0 && which<getSource(0).getNumParameters(inStream);
        if(passedThrough) {
            return getSource(0).getParameterDescription(inStream,which);
        }
        return "Group Number";
    }

    /**
     * Gives the number of values of an output stream.  Even output streams have
     * the values of input stream <code>stream/2</code>; odd output streams have
     * one value per cluster variable.
     * 
     * @param stream the output stream
     * @return the value count, or 0 when there is no source
     */
    public int getNumValues(int stream) {
        if(getNumSources()<1) return 0;
        return (stream%2==0) ? getSource(0).getNumValues(stream/2) : clusterValues.size();
    }

    /**
     * Describes a value of an output stream.  Even output streams pass the
     * values of input stream <code>stream/2</code> through, so the description
     * comes from the source's value descriptions; odd output streams hold the
     * cluster variables and use their names.
     * <p>
     * The pass-through case asks the source for its value description, matching
     * {@link #getNumValues(int)}, rather than for a parameter description.
     * 
     * @param stream the output stream
     * @param which the value index
     * @return the description, or "None" when there is no source
     */
    public String getValueDescription(int stream, int which) {
        if(getNumSources()<1) return "None";
        int inStream=stream/2;
        if(stream%2==0) {
            return getSource(0).getValueDescription(inStream,which);
        } else {
            return clusterValues.get(which).name.getValue();
        }
    }

    /**
     * Gives the display name of this filter type.
     * 
     * @return the string "Cluster Filter"
     */
    public static String getTypeDescription() {
        return "Cluster Filter";
    }

    /**
     * Gives the description of this filter instance.
     * 
     * @return the string "Cluster Filter"
     */
    public String getDescription() {
        return "Cluster Filter";
    }

    /**
     * Creates a copy of this filter through the private copy constructor.
     * 
     * @param l element list handed on to the copy constructor
     * @return the new filter
     */
    public GraphElement copy(List<GraphElement> l) {
        return new ClusterFilter(this,l);
    }
    
    /**
     * Makes sure the output holds two data lists for every stream of the first
     * input.  If the output currently has more lists than the input has streams,
     * all lists are dropped first; new empty lists are then appended until there
     * are twice as many lists as input streams.
     * <p>
     * NOTE(review): the drop test compares against the stream count while the
     * fill target is twice that, so an already correctly sized output is dropped
     * and recreated on every call.  redoAllElements appends to these lists
     * without clearing them, so this may be relied upon - confirm before changing.
     */
    protected void sizeDataVectToInputStreams() {
        int inStreams=inputVector.get(0).getNumStreams();
        if(dataVect.size()>inStreams) {
            dataVect.clear();
        }
        int wanted=inStreams*2;
        for(int have=dataVect.size(); have<wanted; ++have) {
            dataVect.add(new ArrayList<DataElement>());
        }
    }
    
    // When true, redoAllElements compares current clusters (using the accumulated
    // cluster variables) instead of individual elements.
    private EditableBoolean compareClusters=new EditableBoolean(false);
    // When true, candidates are limited to the neighbors returned by a KDTree
    // built from spatialValues.
    private EditableBoolean useSpatialTree=new EditableBoolean(false);
    // One formula per spatial dimension of the tree; also published as variables
    // named <name>i and <name>j.
    private List<VariableDefinition> spatialValues=new ArrayList<VariableDefinition>();
    // Distance used for the neighbor search along every spatial dimension.
    private DataFormula searchDistance=new DataFormula("1");
    // Per-cluster accumulated variables; they are the values of the odd output streams.
    private List<ClusterVariableDefinition> clusterValues=new ArrayList<ClusterVariableDefinition>();
    // Helper variables evaluated in order for each candidate pair before the merge criteria.
    private List<VariableDefinition> generalValues=new ArrayList<VariableDefinition>();
    // Two candidates are unioned when this formula evaluates to true.
    private BooleanFormula mergeCriteria=new BooleanFormula("1=1");

    // Progress text shown in the properties panel; null until the panel is built.
    private transient JLabel statusLabel;
    private static final long serialVersionUID = -3267078395956548632L;
    
    /**
     * Definition of one cluster variable: a name, a value formula, a weight
     * formula and the way the accumulated sums are turned into the cluster's
     * value (0 = sum, anything else = weighted average).
     */
    public static class ClusterVariableDefinition implements Serializable,ListOfPropertyComponents.PropertiedComponent {
        private static final long serialVersionUID = -827917979571505970L;

        private EditableString name=new EditableString("x");
        private DataFormula value=new DataFormula("0");
        private DataFormula weight=new DataFormula("1");
        // Index into the combo box options: 0 = "Sum", 1 = "Weighted Average".
        private int combineStyle;
        // Lazily built editing panel; not serialized.
        private transient JPanel propPanel;

        /** Creates a definition with the default name, formulas and combine style. */
        public ClusterVariableDefinition() {}

        /**
         * @param n variable name
         * @param v value formula
         * @param w weight formula
         * @param type combine style, 0 for sum and 1 for weighted average
         */
        public ClusterVariableDefinition(String n,String v,String w,int type) {
            this.name.setValue(n);
            this.value.setFormula(v);
            this.weight.setFormula(w);
            this.combineStyle=type;
        }

        /**
         * Copies the name, both formulas and the combine style of another definition.
         * 
         * @param c the definition to copy
         */
        public ClusterVariableDefinition(ClusterVariableDefinition c) {
            this.name.setValue(c.name.getValue());
            this.value.setFormula(c.value.getFormula());
            this.weight.setFormula(c.weight.getFormula());
            this.combineStyle=c.combineStyle;
        }

        /**
         * Turns accumulated sums into the value of the cluster variable.
         * 
         * @param v accumulated value sum
         * @param w accumulated weight sum
         * @return <code>v</code> for the sum style, <code>v/w</code> otherwise
         */
        public double calcValue(double v,double w) {
            return combineStyle==0 ? v : v/w;
        }

        /**
         * Returns the editing panel for this definition, building it on first use.
         */
        public JPanel getPropertiesPanel() {
            if(propPanel!=null) {
                return propPanel;
            }
            propPanel=new JPanel(new BorderLayout());
            JPanel fields=new JPanel(new GridLayout(4,1));
            fields.add(name.getLabeledTextField("Variable Name", null));
            fields.add(value.getLabeledTextField("Value Formula", null));
            fields.add(weight.getLabeledTextField("Weight Formula", null));
            final JComboBox styleBox=new JComboBox(new String[] {"Sum","Weighted Average"});
            styleBox.setSelectedIndex(combineStyle);
            styleBox.addActionListener(new ActionListener() {
                @Override
                public void actionPerformed(ActionEvent e) {
                    // A negative index means nothing is selected; keep the old style then.
                    int selected=styleBox.getSelectedIndex();
                    if(selected<0) return;
                    combineStyle=selected;
                }
            });
            fields.add(styleBox);
            propPanel.add(fields,BorderLayout.NORTH);
            return propPanel;
        }

        /** Gives "name : value formula : weight formula". */
        public String toString() {
            return name.getValue()+" : "+value.getFormula()+" : "+weight.getFormula();
        }
    }
    
    /**
     * A named formula.  Used both for the spatial coordinates and for the
     * general helper variables.
     */
    public static class VariableDefinition implements Serializable,ListOfPropertyComponents.PropertiedComponent {
        private static final long serialVersionUID = 6231037736320240932L;

        private EditableString name=new EditableString("x");
        private DataFormula value=new DataFormula("0");
        // Lazily built editing panel; not serialized.
        private transient JPanel propPanel;

        /** Creates a definition with the default name and formula. */
        public VariableDefinition() {}

        /**
         * @param n variable name
         * @param v formula text
         */
        public VariableDefinition(String n,String v) {
            this.name.setValue(n);
            this.value.setFormula(v);
        }

        /**
         * Copies the name and formula of another definition.
         * 
         * @param c the definition to copy
         */
        public VariableDefinition(VariableDefinition c) {
            this.name.setValue(c.name.getValue());
            this.value.setFormula(c.value.getFormula());
        }

        /**
         * Returns the editing panel for this definition, building it on first use.
         */
        public JPanel getPropertiesPanel() {
            if(propPanel!=null) {
                return propPanel;
            }
            propPanel=new JPanel(new BorderLayout());
            JPanel fields=new JPanel(new GridLayout(2,1));
            fields.add(name.getLabeledTextField("Variable Name", null));
            fields.add(value.getLabeledTextField("Variable Formula", null));
            propPanel.add(fields,BorderLayout.NORTH);
            return propPanel;
        }

        /** Gives "name : formula". */
        public String toString() {
            return name.getValue()+" : "+value.getFormula();
        }
    }
    
    /**
     * Accumulated sums of one cluster, one slot per cluster variable:
     * <code>weightSum</code> holds the summed weights and <code>valueSum</code>
     * the summed value*weight products.
     */
    private static class ClusterInfo {
        private double[] valueSum;
        private double[] weightSum;

        /**
         * Starts a cluster that consists of a single element.
         * 
         * @param cf filter that supplies the cluster variable definitions
         * @param s input stream of the element
         * @param i index of the element
         */
        public ClusterInfo(ClusterFilter cf,int s,int i) {
            int count=cf.clusterValues.size();
            valueSum=new double[count];
            weightSum=new double[count];
            for(int k=0; k<count; ++k) {
                ClusterVariableDefinition def=cf.clusterValues.get(k);
                // The weight is evaluated first; the stored value is already weighted.
                double w=def.weight.valueOf(cf,s,i);
                weightSum[k]=w;
                valueSum[k]=def.value.valueOf(cf,s,i)*w;
            }
        }

        /**
         * Adds the sums of another cluster to this one and writes the combined
         * sums into both objects, so either can be used for the merged cluster.
         * 
         * @param ci the cluster to combine with
         */
        public void mergeWith(ClusterInfo ci) {
            for(int k=0; k<valueSum.length; ++k) {
                double totalValue=valueSum[k]+ci.valueSum[k];
                double totalWeight=weightSum[k]+ci.weightSum[k];
                valueSum[k]=totalValue;
                weightSum[k]=totalWeight;
                ci.valueSum[k]=totalValue;
                ci.weightSum[k]=totalWeight;
            }
        }
    }
    
    /**
     * Point stored in the spatial tree: the index of the element it was built
     * from and a private copy of its coordinates.
     */
    private static class CFTreePoint implements KDTree.TreePoint {
        private int index;
        private double[] v;

        /**
         * @param i index of the element
         * @param vals coordinates; copied, so the caller may reuse the array
         */
        public CFTreePoint(int i,double[] vals) {
            this.index=i;
            this.v=vals.clone();
        }

        /** Gives the coordinate along dimension <code>val</code>. */
        @Override
        public double getVal(int val) {
            return v[val];
        }
    }
}
